{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction to Deep Learning with PyTorch\n",
    "\n",
    "In this notebook, we will go over the basics of deep learning using PyTorch. PyTorch is a framework for manipulating tensors through an API similar to NumPy's. It also provides automatic differentiation (autograd), which computes the gradients needed for back-propagation, the algorithm at the heart of training a neural network.\n",
    "\n",
    "## What are Neural Networks\n",
    "\n",
    "Deep learning is based on artificial neural networks, which are made up of neurons. A neuron takes inputs, calculates their weighted sum (plus a bias term) and then passes that sum through a non-linear function, called an activation function, as shown below:\n",
    "\n",
    "![Neuron](./images/neuron.png \"Neuron\")\n",
    "\n",
    "<br/>\n",
    "<br/>\n",
    "\n",
    "We stack these neurons in layers to build a neural network, as shown below:\n",
    "![Neural network](./images/nn.svg \"Neural network\")\n",
    "\n",
    "Let us now create a neural network in PyTorch. We will train it on the MNIST dataset so that it takes an image of a handwritten digit as input and predicts the digit (class) it belongs to.\n"
   ]
  },
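  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The weighted-sum-plus-activation computation of a single neuron can be sketched in a few lines of PyTorch. This is only an illustration, not part of the MNIST model: the inputs, weights and bias are made-up numbers, and ReLU is assumed as the activation function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Made-up inputs, weights and bias for one neuron with three inputs\n",
    "x = torch.tensor([1.0, 2.0, 3.0])\n",
    "w = torch.tensor([0.5, -1.0, 0.25])\n",
    "b = torch.tensor(0.5)\n",
    "\n",
    "# Weighted sum of the inputs plus the bias: 0.5 - 2.0 + 0.75 + 0.5 = -0.25\n",
    "z = torch.dot(w, x) + b\n",
    "\n",
    "# ReLU activation: negative sums are clipped to zero\n",
    "a = torch.relu(z)\n",
    "print(z.item(), a.item())  # -0.25 0.0"
   ]
  },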
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "from torch.optim.lr_scheduler import StepLR\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MNIST\n",
    "\n",
    "The MNIST dataset consists of grayscale images of handwritten digits, each 28x28 = 784 pixels.\n",
    "\n",
    "We will have 10 units at the output layer, one for each digit (0-9) an image can belong to.\n",
    "\n",
    "Let us first load the data and plot one of the images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert images to tensors (pixel values in [0, 1]), then normalize them to [-1, 1] via (x - 0.5) / 0.5\n",
    "transform = transforms.Compose([transforms.ToTensor(),\n",
    "                              transforms.Normalize(0.5, 0.5),\n",
    "                             ])\n",
    "# Download and load the training data\n",
    "trainset = datasets.MNIST('MNIST_data/', download=True, train=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)\n",
    "\n",
    "# Download and load the test data\n",
    "testset = datasets.MNIST('MNIST_data/', download=True, train=False, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=True)"
   ]
  },
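  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`ToTensor()` scales pixel values to [0, 1], and `Normalize(0.5, 0.5)` then computes `(x - mean) / std` for each pixel, which maps that range onto [-1, 1]. The check below uses a tiny made-up tensor in place of a real MNIST image, so nothing is downloaded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchvision import transforms\n",
    "\n",
    "# Tiny made-up 1x2x2 'image' whose pixels are already in [0, 1] (as after ToTensor)\n",
    "x = torch.tensor([[[0.0, 0.25], [0.5, 1.0]]])\n",
    "\n",
    "y = transforms.Normalize(0.5, 0.5)(x)\n",
    "\n",
    "# The same result by hand: (x - mean) / std\n",
    "manual = (x - 0.5) / 0.5\n",
    "print(y)\n",
    "print(torch.allclose(y, manual))  # True"
   ]
  },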
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let us load a batch of training data and check its shape\n",
    "dataiter = iter(trainloader)\n",
    "images, labels = next(dataiter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 1, 28, 28])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "images.size()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A tensor of size `(128, 1, 28, 28)` means that the batch holds 128 images (the `batch_size` passed to the `DataLoader`), each of shape 1x28x28 (channels x height x width)."
   ]
  },
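  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The fully connected layers used below expect each image as a flat vector of 1 x 28 x 28 = 784 values, so a batch of shape `(128, 1, 28, 28)` must be reshaped to `(128, 784)`. A sketch with random tensors standing in for real batches: `view(batch_size, -1)` lets PyTorch infer the 784, and it also handles the smaller final batch of an epoch (60000 = 468 x 128 + 96)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Random stand-in for one batch: (batch, channels, height, width)\n",
    "images = torch.randn(128, 1, 28, 28)\n",
    "\n",
    "# Keep the batch dimension and flatten the rest: 1 * 28 * 28 = 784 features\n",
    "flat = images.view(images.shape[0], -1)\n",
    "print(flat.shape)  # torch.Size([128, 784])\n",
    "\n",
    "# The same call works for the last, smaller batch of the training set\n",
    "last_batch = torch.randn(96, 1, 28, 28)\n",
    "last_flat = last_batch.view(last_batch.shape[0], -1)\n",
    "print(last_flat.shape)  # torch.Size([96, 784])"
   ]
  },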
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f36243383d0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAauUlEQVR4nO3df2zU9R3H8dfxowdqe6zW9nryqwWVRX4YUboGQRgdpVsMKDHqNMJiJGAxA/yVLlNwW1LHsulcKu6Phc5N8Ec2YJqNRIst+1EwgIyZbQ2t3VoDLYOld6XYQtrP/iDeOGmB73HXd+94PpJPYu++n97b706eu97xrc855wQAwCAbZj0AAODKRIAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAICJEdYDfFFfX5+OHDmizMxM+Xw+63EAAB4559TZ2alQKKRhwwZ+nTPkAnTkyBGNGzfOegwAwGVqbW3V2LFjB7x/yP0ILjMz03oEAEACXOzP86QFqKqqShMnTtSoUaNUVFSkDz/88JL28WM3AEgPF/vzPCkBevPNN7Vu3TqtX79eBw4c0IwZM1RaWqpjx44l4+EAAKnIJcGsWbNceXl59Ove3l4XCoVcZWXlRfeGw2EnicVisVgpvsLh8AX/vE/4K6DTp09r//79Kikpid42bNgwlZSUqL6+/rzje3p6FIlEYhYAIP0lPEDHjx9Xb2+v8vLyYm7Py8tTW1vbecdXVlYqEAhEF5+AA4Arg/mn4CoqKhQOh6OrtbXVeiQAwCBI+N8DysnJ0fDhw9Xe3h5ze3t7u4LB4HnH+/1++f3+RI8BABjiEv4KKCMjQzNnzlRNTU30tr6+PtXU1Ki4uDjRDwcASFFJuRLCunXrtGzZMt12222aNWuWXnrpJXV1delb3/pWMh4OAJCCkhKg++67T//5z3/03HPPqa2tTbfccot27tx53gcTAABXLp9zzlkPca5IJKJAIGA9BgDgMoXDYWVlZQ14v/mn4AAAVyYCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADAxAjrAYCLGT16tOc9P/3pT+N6rIcfftjznq1bt3reU15e7nnPqVOnPO8BhjJeAQEATBAgAICJhAdow4YN8vl8MWvKlCmJfhgAQIpLyntAN998s95///3/P8gI3moCAMRKShlGjBihYDCYjG8NAEgTSXkP6PDhwwqFQiosLNSDDz6olpaWAY/t6elRJBKJWQCA9JfwABUVFam6ulo7d+7Upk2b1NzcrDlz5qizs7Pf4ysrKxUIBKJr3LhxiR4JADAEJTxAZWVluvfeezV9+nSVlpbq97//vTo6OvTWW2/1e3xFRYXC4XB0tba2JnokAMAQlPRPB4wZM0Y33nijGhsb+73f7/fL7/cnewwAwBCT9L8HdPLkSTU1NSk/Pz/ZDwUASCEJD9CTTz6puro6/etf/9Jf/vIX3X333Ro+fLgeeOCBRD8UACCFJfxHcJ9++qkeeOABnThxQtddd53uuOMO7dmzR9ddd12iHwoAkMISHqA33ngj0d8SaSSeC4v+9a9/9bynsLDQ8x5J+uSTTzzveeihhzzv8fl8nvcsX77c8x5gKONacAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACAiaT/QjrgXHfeeafnPZMnT/a8
54knnvC8R5JefPFFz3uampo87/nvf//rec9tt93mec++ffs87wEGC6+AAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIKrYWNQPfPMM4PyONXV1YPyOJL0xz/+0fOetWvXet5z6623et4zb948z3uAwcIrIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABBcjxaAKBAKe9xw8eNDznnA47HlPvB577DHPe2bNmpWESYDUwisgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEFyPFoPL5fJ731NfXe97T19fneU+88vPzPe8ZP3685z0jRvCfK9ILr4AAACYIEADAhOcA7d69W3fddZdCoZB8Pp+2b98ec79zTs8995zy8/M1evRolZSU6PDhw4maFwCQJjwHqKurSzNmzFBVVVW/92/cuFEvv/yyXn31Ve3du1dXX321SktL1d3dfdnDAgDSh+d3NcvKylRWVtbvfc45vfTSS/rud7+rxYsXS5Jee+015eXlafv27br//vsvb1oAQNpI6HtAzc3NamtrU0lJSfS2QCCgoqKiAT/J1NPTo0gkErMAAOkvoQFqa2uTJOXl5cXcnpeXF73viyorKxUIBKJr3LhxiRwJADBEmX8KrqKiQuFwOLpaW1utRwIADIKEBigYDEqS2tvbY25vb2+P3vdFfr9fWVlZMQsAkP4SGqCCggIFg0HV1NREb4tEItq7d6+Ki4sT+VAAgBTn+VNwJ0+eVGNjY/Tr5uZmHTx4UNnZ2Ro/frzWrFmjH/zgB7rhhhtUUFCgZ599VqFQSEuWLEnk3ACAFOc5QPv27dP8+fOjX69bt06StGzZMlVXV+vpp59WV1eXVqxYoY6ODt1xxx3auXOnRo0albipAQApz+ecc9ZDnCsSiSgQCFiPgSQ5cOCA5z2TJk3yvGeg9xwvZvjw4Z73/O53v/O8Z968eZ73nDlzxvOenJwcz3skqbOzM659wLnC4fAF39c3/xQcAODKRIAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABOefx0DcDmeffZZz3viudr0wYMHPe+RpFOnTnneM23aNM97/vCHP3jeU1ZW5nnPQw895HmPJG3atCmufYAXvAICAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEz4nHPOeohzRSIRBQIB6zEwhLzyyiue9yxevDiux4rnYqQbNmzwvOc3v/mN5z2NjY2e9xw/ftzzHkm65ZZb4toHnCscDisrK2vA+3kFBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCY4GKkQIrYvHmz5z0PP/xwXI/1ta99zfOeXbt2xfVYSF9cjBQAMCQRIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACZGWA8A4NJ88MEHnvcsW7Ysrse66qqr4toHeMErIACACQIEADDhOUC7d+/WXXfdpVAoJJ/Pp+3bt8fcv3z5cvl8vpi1aNGiRM0LAEgTngPU1dWlGTNmqKqqasBjFi1apKNHj0bX1q1bL2tIAED68fwhhLKyMpWVlV3wGL/fr2AwGPdQAID0l5T3gGpra5Wbm6ubbrpJq1at0okTJwY8tqenR5FIJGYBANJfwgO0aNEivfbaa6qpqdEPf/hD1dXVqaysTL29vf0eX1lZqUAgEF3jxo1L9EgAgCEo4X8P6P7774/+87Rp0zR9+nRNmjRJtbW1WrBgwXnHV1RUaN26ddGvI5EIEQKAK0DSP4Zd
WFionJwcNTY29nu/3+9XVlZWzAIApL+kB+jTTz/ViRMnlJ+fn+yHAgCkEM8/gjt58mTMq5nm5mYdPHhQ2dnZys7O1vPPP6+lS5cqGAyqqalJTz/9tCZPnqzS0tKEDg4ASG2eA7Rv3z7Nnz8/+vXn798sW7ZMmzZt0qFDh/TLX/5SHR0dCoVCWrhwob7//e/L7/cnbmoAQMrzOeec9RDnikQiCgQC1mMAaaGlpSWufQcOHPC859wPIF2q7u5uz3uQOsLh8AXf1+dacAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADCR8F/JDWDoqKysjGtfVVWV5z3xXA27urra8x6kD14BAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmuBgpkMbivdjnCy+84HnPzTffHNdj4crFKyAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQXIwXS2GeffRbXvmPHjnnek5OTE9dj4crFKyAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQXIwWQEJ988on1CEgxvAICAJggQAAAE54CVFlZqdtvv12ZmZnKzc3VkiVL1NDQEHNMd3e3ysvLde211+qaa67R0qVL1d7entChAQCpz1OA6urqVF5erj179ui9997TmTNntHDhQnV1dUWPWbt2rd555x29/fbbqqur05EjR3TPPfckfHAAQGrz9CGEnTt3xnxdXV2t3Nxc7d+/X3PnzlU4HNYvfvELbdmyRV/96lclSZs3b9aXv/xl7dmzR1/5ylcSNzkAIKVd1ntA4XBYkpSdnS1J2r9/v86cOaOSkpLoMVOmTNH48eNVX1/f7/fo6elRJBKJWQCA9Bd3gPr6+rRmzRrNnj1bU6dOlSS1tbUpIyNDY8aMiTk2Ly9PbW1t/X6fyspKBQKB6Bo3bly8IwEAUkjcASovL9fHH3+sN95447IGqKioUDgcjq7W1tbL+n4AgNQQ119EXb16td59913t3r1bY8eOjd4eDAZ1+vRpdXR0xLwKam9vVzAY7Pd7+f1++f3+eMYAAKQwT6+AnHNavXq1tm3bpl27dqmgoCDm/pkzZ2rkyJGqqamJ3tbQ0KCWlhYVFxcnZmIAQFrw9AqovLxcW7Zs0Y4dO5SZmRl9XycQCGj06NEKBAJ65JFHtG7dOmVnZysrK0uPP/64iouL+QQcACCGpwBt2rRJkjRv3ryY2zdv3qzly5dLkl588UUNGzZMS5cuVU9Pj0pLS/XKK68kZFgAQPrwFCDn3EWPGTVqlKqqqlRVVRX3UABsXcp/68Dl4lpwAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMBHXb0QFhrqMjIy49t17772e92RnZ3ve86tf/crzno6ODs974v09XIWFhZ73fPjhh3E9Fq5cvAICAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEz4nHPOeohzRSIRBQIB6zGQ4mbPnh3XvnguEjpx4kTPe3p7ez3vaW5u9rwnFAp53iNJf/vb3zzvmT9/vuc93d3dnvcgdYTDYWVlZQ14P6+AAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATXIwUOEdGRobnPfFc+PTHP/6x5z0XuqjjQEaMGOF5jyTNmTPH857W1ta4Hgvpi4uRAgCGJAIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABBcjBQAkBRcj
BQAMSQQIAGDCU4AqKyt1++23KzMzU7m5uVqyZIkaGhpijpk3b558Pl/MWrlyZUKHBgCkPk8BqqurU3l5ufbs2aP33ntPZ86c0cKFC9XV1RVz3KOPPqqjR49G18aNGxM6NAAg9Xn6dYk7d+6M+bq6ulq5ubnav3+/5s6dG739qquuUjAYTMyEAIC0dFnvAYXDYUlSdnZ2zO2vv/66cnJyNHXqVFVUVOjUqVMDfo+enh5FIpGYBQC4Arg49fb2um984xtu9uzZMbf//Oc/dzt37nSHDh1yv/71r93111/v7r777gG/z/r1650kFovFYqXZCofDF+xI3AFauXKlmzBhgmttbb3gcTU1NU6Sa2xs7Pf+7u5uFw6Ho6u1tdX8pLFYLBbr8tfFAuTpPaDPrV69Wu+++652796tsWPHXvDYoqIiSVJjY6MmTZp03v1+v19+vz+eMQAAKcxTgJxzevzxx7Vt2zbV1taqoKDgonsOHjwoScrPz49rQABAevIUoPLycm3ZskU7duxQZmam2traJEmBQECjR49WU1OTtmzZoq9//eu69tprdejQIa1du1Zz587V9OnTk/IvAABIUV7e99EAP+fbvHmzc865lpYWN3fuXJedne38fr+bPHmye+qppy76c8BzhcNh859bslgsFuvy18X+7OdipACApOBipACAIYkAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYGLIBcg5Zz0CACABLvbn+ZALUGdnp/UIAIAEuNif5z43xF5y9PX16ciRI8rMzJTP54u5LxKJaNy4cWptbVVWVpbRhPY4D2dxHs7iPJzFeThrKJwH55w6OzsVCoU0bNjAr3NGDOJMl2TYsGEaO3bsBY/Jysq6op9gn+M8nMV5OIvzcBbn4Szr8xAIBC56zJD7ERwA4MpAgAAAJlIqQH6/X+vXr5ff77cexRTn4SzOw1mch7M4D2el0nkYch9CAABcGVLqFRAAIH0QIACACQIEADBBgAAAJlImQFVVVZo4caJGjRqloqIiffjhh9YjDboNGzbI5/PFrClTpliPlXS7d+/WXXfdpVAoJJ/Pp+3bt8fc75zTc889p/z8fI0ePVolJSU6fPiwzbBJdLHzsHz58vOeH4sWLbIZNkkqKyt1++23KzMzU7m5uVqyZIkaGhpijunu7lZ5ebmuvfZaXXPNNVq6dKna29uNJk6OSzkP8+bNO+/5sHLlSqOJ+5cSAXrzzTe1bt06rV+/XgcOHNCMGTNUWlqqY8eOWY826G6++WYdPXo0uv70pz9Zj5R0XV1dmjFjhqqqqvq9f+PGjXr55Zf16quvau/evbr66qtVWlqq7u7uQZ40uS52HiRp0aJFMc+PrVu3DuKEyVdXV6fy8nLt2bNH7733ns6cOaOFCxeqq6sreszatWv1zjvv6O2331ZdXZ2OHDmie+65x3DqxLuU8yBJjz76aMzzYePGjUYTD8ClgFmzZrny8vLo1729vS4UCrnKykrDqQbf+vXr3YwZM6zHMCXJbdu2Lfp1X1+fCwaD7kc/+lH0to6ODuf3+93WrVsNJhwcXzwPzjm3bNkyt3jxYpN5rBw7dsxJcnV1dc65s//bjxw50r399tvRY/7xj384Sa6+vt5qzKT74nlwzrk777zTffvb37Yb6hIM+VdAp0+f1v79+1VSUhK9bdiwYSopKVF9fb3hZDYOHz6sUCikwsJCPfjgg2ppabEeyVRzc7Pa2tpinh+BQEBFRUVX5POjtrZWubm5uummm7Rq1SqdOHHCeqSkCofDkqTs7GxJ0v79+3XmzJmY58OUKVM0fvz4tH4+fPE8fO71119XTk6Opk6dqoqKCp06dcpivAENuYuRftHx48fV29urvLy8mNvz8vL0z3/+
02gqG0VFRaqurtZNN92ko0eP6vnnn9ecOXP08ccfKzMz03o8E21tbZLU7/Pj8/uuFIsWLdI999yjgoICNTU16Tvf+Y7KyspUX1+v4cOHW4+XcH19fVqzZo1mz56tqVOnSjr7fMjIyNCYMWNijk3n50N/50GSvvnNb2rChAkKhUI6dOiQnnnmGTU0NOi3v/2t4bSxhnyA8H9lZWXRf54+fbqKioo0YcIEvfXWW3rkkUcMJ8NQcP/990f/edq0aZo+fbomTZqk2tpaLViwwHCy5CgvL9fHH398RbwPeiEDnYcVK1ZE/3natGnKz8/XggUL1NTUpEmTJg32mP0a8j+Cy8nJ0fDhw8/7FEt7e7uCwaDRVEPDmDFjdOONN6qxsdF6FDOfPwd4fpyvsLBQOTk5afn8WL16td5991198MEHMb++JRgM6vTp0+ro6Ig5Pl2fDwOdh/4UFRVJ0pB6Pgz5AGVkZGjmzJmqqamJ3tbX16eamhoVFxcbTmbv5MmTampqUn5+vvUoZgoKChQMBmOeH5FIRHv37r3inx+ffvqpTpw4kVbPD+ecVq9erW3btmnXrl0qKCiIuX/mzJkaOXJkzPOhoaFBLS0tafV8uNh56M/BgwclaWg9H6w/BXEp3njjDef3+111dbX7+9//7lasWOHGjBnj2trarEcbVE888YSrra11zc3N7s9//rMrKSlxOTk57tixY9ajJVVnZ6f76KOP3EcffeQkuZ/85Cfuo48+cv/+97+dc8698MILbsyYMW7Hjh3u0KFDbvHixa6goMB99tlnxpMn1oXOQ2dnp3vyySddfX29a25udu+//7679dZb3Q033OC6u7utR0+YVatWuUAg4Gpra93Ro0ej69SpU9FjVq5c6caPH+927drl9u3b54qLi11xcbHh1Il3sfPQ2Njovve977l9+/a55uZmt2PHDldYWOjmzp1rPHmslAiQc8797Gc/c+PHj3cZGRlu1qxZbs+ePdYjDbr77rvP5efnu4yMDHf99de7++67zzU2NlqPlXQffPCBk3TeWrZsmXPu7Eexn332WZeXl+f8fr9bsGCBa2hosB06CS50Hk6dOuUWLlzorrvuOjdy5Eg3YcIE9+ijj6bd/0nr799fktu8eXP0mM8++8w99thj7ktf+pK76qqr3N133+2OHj1qN3QSXOw8tLS0uLlz57rs7Gzn9/vd5MmT3VNPPeXC4bDt4F/Ar2MAAJgY8u8BAQDSEwECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABg4n9Aoo0Ee2nJJQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Let us plot one image\n",
    "plt.imshow(images[10].numpy().squeeze(), cmap='Greys_r')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NN(\n",
       "  (fc1): Linear(in_features=784, out_features=192, bias=True)\n",
       "  (fc2): Linear(in_features=192, out_features=128, bias=True)\n",
       "  (fc3): Linear(in_features=128, out_features=10, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class NN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(784, 192)\n",
    "        self.fc2 = nn.Linear(192, 128)\n",
    "        self.fc3 = nn.Linear(128, 10)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        ''' Forward pass through the network; returns the output logits '''\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.fc3(x)\n",
    "        \n",
    "        return x\n",
    "    \n",
    "    def predict(self, x):\n",
    "        ''' Return class probabilities by applying softmax to the logits '''\n",
    "        logits = self(x)\n",
    "        return F.softmax(logits, dim=1)\n",
    "\n",
    "model = NN()\n",
    "model"
   ]
  },
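  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on this architecture: each `nn.Linear(in, out)` holds an `out x in` weight matrix plus `out` biases, giving 784x192 + 192 + 192x128 + 128 + 128x10 + 10 = 176,714 trainable parameters, and a batch of flattened images should come out as 10 logits per image. To keep the cell self-contained, the sketch rebuilds the same layer sizes with `nn.Sequential` as a stand-in for the `NN` class; `sum(p.numel() for p in model.parameters())` should give the same count for `model`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Same layer sizes as the NN class: 784 -> 192 -> 128 -> 10, with ReLU in between\n",
    "net = nn.Sequential(\n",
    "    nn.Linear(784, 192), nn.ReLU(),\n",
    "    nn.Linear(192, 128), nn.ReLU(),\n",
    "    nn.Linear(128, 10),\n",
    ")\n",
    "\n",
    "# Count all trainable weights and biases\n",
    "n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)\n",
    "print(n_params)  # 176714\n",
    "\n",
    "# A batch of 4 flattened images yields 4 rows of 10 logits\n",
    "out = net(torch.randn(4, 784))\n",
    "print(out.shape)  # torch.Size([4, 10])"
   ]
  },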
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Forward Pass and Class Probabilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper to show an image next to a bar chart of the predicted probability for each digit\n",
    "def view_classification(img, probs):\n",
    "    probs = probs.detach().numpy().squeeze()\n",
    "    fig, (ax1, ax2) = plt.subplots(figsize=(6, 7), ncols=2)\n",
    "    ax1.imshow(img.numpy().squeeze())\n",
    "    ax1.axis('off')\n",
    "    ax2.barh(np.arange(10), probs)\n",
    "    ax2.set_aspect(0.1)\n",
    "    ax2.set_yticks(np.arange(10))\n",
    "    ax2.set_yticklabels(np.arange(10), size='large')\n",
    "    ax2.set_title('Probability')\n",
    "    ax2.set_xlim(0, 1.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAERCAYAAACq8dRTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAmNUlEQVR4nO3deVxU9f4/8NcwyGFzRgU3dNwFU7TUUvGq6M2foF6xvCqYBpaGLWo+tCwqJTQD07TFJb3XcAnTckO9Kko3shJ3WxANMzFAJSNlUHAQ+Pz+6OvcRmYOisA5Hl7Px2MeNZ/lzKsT+PZzljk6IYQAERERKcpJ6QBERETEgkxERKQKLMhEREQqwIJMRESkAizIREREKsCCTEREpAIsyERERCrAgkxERKQCLMhEREQqwIJMRKQgnU6HyZMnV9n2Vq9eDZ1Oh6NHj1Y4tn///ujfv7/1fWZmJnQ6HVavXm1te/PNN6HT6aosHznGgkxEZMetwnbr5erqCl9fX0yePBm5ublKx1PU22+/jW3btikdQ3NYkImIZMyZMwfr1q3DkiVL0Lt3byxfvhwBAQEoLCxUOto927t3L/bu3Ss75o033kBRUZFNGwty9XBWOgARkZoNHjwYDz/8MABg4sSJ8PLywqJFi5CYmIgxY8aUG3/9+nV4eHjUdMxKcXFxqXCMs7MznJ1ZKmoCV8hERHfh73//OwDg3LlzGD9+PDw9PXH27FkMGTIEdevWxdixYwH8WZhnzJgBk8kESZLg5+eHhQsXwtED9hISEuDn5wdXV1d0794d+/fvt+k/f/48nn/+efj5+cHNzQ1eXl4YNWoUMjMz7W6vsLAQkyZNgpeXFwwGA8LDw3HlyhWbMbefQ7bn9nPIOp0O169fx5o1a6yH88ePH48vv/wSOp0OW7duLbeN9evXQ6fTITU1Vfazajv+tYeI6C6cPXsWAODl5QUAKCkpQVBQEPr06YOFCxfC3d0dQgiEhITgyy+/xIQJE/DQQw8hKSkJL7/8MnJycrB48WKbbX711VfYuHEjpk6dCkmSsGzZMgQHB+Pw4cPw9/cHABw5cgQHDhxAWFgYmjdvjszMTCxfvhz9+/dHeno63N3dbbY5efJk1KtXD2+++SZ++uknLF++HOfPn0dKSso9XaS1bt06TJw4ET169EBkZCQAoG3btujVqxdMJhMSEhLw+OOP28xJSEhA27ZtERAQUOnPrRUEERGVEx8fLwCI5ORkcfnyZZGVlSU2bNggvLy8hJubm8jOzhYRERECgHj11Vdt5m7btk0AEG+99ZZN+8iRI4VOpxM///yztQ2AACCOHj1qbTt//rxwdXUVjz/+uLWtsLCwXMbU1FQBQKxdu7Zc7u7du4vi4mJr+zvvvCMAiMTERGtbYGCgCAwMtL4/d+6cACDi4+OtbdHR0eL2UuHh4SEiIiLK5YmKihKSJImrV69a23777Tfh7OwsoqOjy40nWzxkTUQkY+DAgWjYsCFMJhPCwsLg6emJrVu3olmzZtYxzz33nM2cXbt2Qa/XY+rUqTbtM2bMgBACu3fvtmkPCAhA9+7dre9btGiB4cOHIykpCaWlpQAANzc3a//NmzeRl5eHdu3aoV69ejh+/Hi53JGRkahTp45NRmdnZ+zatasSe+HOhIeHw2KxYNOmTda2jRs3oqSkBOPGjau2z9UKHrImIpKxdOlS+Pr6wtnZGY0bN4afnx+cnP63lnF2dkbz5s1t5pw/fx4+Pj6oW7euTfsDDzxg7f+r9u3bl/tcX19fFBYW4vLly2jSpAmKiooQGxuL+Ph45OTk2JyLzs/PLzf/9m16enqiadOmDs85V4UOHTrgkUceQUJCAiZMmADgz8PVvXr1Qrt27artc7WCBZmISEaPHj2sV1nbI0mSTYGuLlOmTEF8fDymTZuGgIAAGI1G6HQ6hIWFoaysrNo//06Fh4fjxRdfRHZ2NiwWCw4ePIglS5YoHeu+
wIJMRFTFWrZsieTkZBQUFNiskk+fPm3t/6szZ86U20ZGRgbc3d3RsGFDAMCmTZsQERGBd9991zrmxo0buHr1qt0MZ86cwYABA6zvr127hosXL2LIkCGV/u+6Re6isLCwMEyfPh2ffvopioqKUKdOHYSGht7zZ9YGPIdMRFTFhgwZgtLS0nIrw8WLF0On02Hw4ME27ampqTbngbOyspCYmIhBgwZBr9cDAPR6fblbpj788EPrOebbrVy5Ejdv3rS+X758OUpKSsp9dmV4eHg4/IuAt7c3Bg8ejE8++QQJCQkIDg6Gt7f3PX9mbcAVMhFRFRs2bBgGDBiA119/HZmZmXjwwQexd+9eJCYmYtq0aWjbtq3NeH9/fwQFBdnc9gQAMTEx1jH/+Mc/sG7dOhiNRnTs2BGpqalITk623n51u+LiYjz66KMYPXo0fvrpJyxbtgx9+vRBSEjIPf/3de/eHcnJyVi0aBF8fHzQunVr9OzZ09ofHh6OkSNHAgDmzp17z59XW7AgExFVMScnJ2zfvh2zZ8/Gxo0bER8fj1atWmHBggWYMWNGufGBgYEICAhATEwMfv31V3Ts2BGrV69Gly5drGPef/996PV6JCQk4MaNG/jb3/6G5ORkBAUF2c2wZMkSJCQkYPbs2bh58ybGjBmDDz74oEoeFLFo0SJERkZav1YzIiLCpiAPGzYM9evXR1lZWZX8BaC20Inbj4EQERHdg5KSEvj4+GDYsGFYtWqV0nHuGzyHTEREVWrbtm24fPkywsPDlY5yX+EKmYiIqsShQ4fwww8/YO7cufD29rb7hSXkGFfIRERUJZYvX47nnnsOjRo1wtq1a5WOc9/hCpmIiEgFuEImIiJSgTu+7en/OY2qzhxEtcq+ss+VjkBEKsP7kInIrrKyMly4cAF169atkntXiWorIQQKCgrg4+Mj+73nLMhEZNeFCxdgMpmUjkGkGVlZWeWeDPZXLMhEZNethyJkZWXBYDAonIbo/mU2m2Eymco9jvN2LMhEZNetw9QGg4EFmagKVHTqh1dZExERqQALMhERkQqwIBMREakACzIREZEKsCATERGpAAsyERGRCrAgExERqQDvQyYiWf7RSXCS3Ktse5lxQ6tsW0RawhUykcYcO3YMwcHBMBgMqFu3LgYNGoTvvvtO6VhEVAGukKlKFIT2ku1fFve+w76ZYyNl5+oOfF+pTLXR8ePH0adPH5hMJkRHR6OsrAzLli1DYGAgDh8+DD8/P6UjEpEDLMhEGjJr1iy4ubkhNTUVXl5eAIBx48bB19cXr732GjZv3qxwQiJyhIesiTTk66+/xsCBA63FGACaNm2KwMBA7Ny5E9euXVMwHRHJYUEm0hCLxQI3N7dy7e7u7iguLkZaWpoCqYjoTvCQNZGG+Pn54eDBgygtLYVerwcAFBcX49ChQwCAnJwch3MtFgssFov1vdlsrt6wRGSDK2QiDXn++eeRkZGBCRMmID09HWlpaQgPD8fFixcBAEVFRQ7nxsbGwmg0Wl8mk6mmYhMRWJCJNOXZZ5/Fa6+9hvXr16NTp07o3Lkzzp49i5kzZwIAPD09Hc6NiopCfn6+9ZWVlVVTsYkILMhEmjNv3jzk5ubi66+/xg8//IAjR46grKwMAODr6+twniRJMBgMNi8iqjk8h0x3TO/b1mHfjLnrZed2dqnjsC+vs/y3QHkfkM9F5dWvXx99+vSxvk9OTkbz5s3RoUMHBVMRkRyukIk0buPGjThy5AimTZsGJyf+yhOpFVfIRBqyf/9+zJkzB4MGDYKXlxcOHjyI+Ph4BAcH48UXX1Q6HhHJYEEm0pBmzZpBr9djwYIFKCgoQOvWrfHWW29h+vTpcHbmrzuRmvE3lEhD2rZti6SkJKVjEFElsCATkay0mCBecU1UA3iFBxERkQpwhUx37GxEI4d9j3lclZ3bbo/jRyz6rToiO1fI9hIRaQNXyERERCrAFTIRyfKPToKTJP/lLXcjM25olW2L
SEu4QiYiIlIBFmQiDTpz5gzCwsLQvHlzuLu7o0OHDpgzZw4KCwuVjkZEDvCQNZHGZGVloUePHjAajZg8eTIaNGiA1NRUREdH49ixY0hMTFQ6IhHZwYJMpDHr1q3D1atX8c0336BTp04AgMjISJSVlWHt2rW4cuUK6tevr3BKIrodD1kTaYzZbAYANG7c2Ka9adOmcHJygouLixKxiKgCXCHXMs7NfBz2XY+X/4N6X4cFDvu6LJkpO9dvwWGHfaKkRHYu3Z3+/ftj/vz5mDBhAmJiYuDl5YUDBw5g+fLlmDp1Kjw8POzOs1gssFgs1ve3CjsR1QyukIk0Jjg4GHPnzsW+ffvQtWtXtGjRAmFhYZgyZQoWL17scF5sbCyMRqP1ZTKZajA1EbEgE2lQq1at0K9fP6xcuRKbN2/G008/jbfffhtLlixxOCcqKgr5+fnWV1ZWVg0mJiIesibSmA0bNiAyMhIZGRlo3rw5AGDEiBEoKyvDK6+8gjFjxsDLy6vcPEmSIElSTcclov/DFTKRxixbtgxdu3a1FuNbQkJCUFhYiBMnTiiUjIjksCATaUxubi5KS0vLtd+8eRMAUMKL6IhUiQWZSGN8fX1x4sQJZGRk2LR/+umncHJyQpcuXRRKRkRyeA65ljkzpaXDvlOdlsrO9UuZ4rCvbewB2bl8hGLNefnll7F792707dsXkydPhpeXF3bu3Indu3dj4sSJ8PFxfOsbESmHBZlIY/r164cDBw7gzTffxLJly5CXl4fWrVtj3rx5mDlT/n5xIlIOCzKRBvXo0QO7du2qkm2lxQTBYDBUybaIyDGeQyYiIlIBFmQiIiIVYEEmIiJSAZ5DJiJZ/tFJcJLc73peZtzQakhDpF1cIRMREakAV8gaIwIelO3v2/9Hh30L//CTnev30kWHffzuJ/UYP3481qxZ47A/OzsbzZo1q8FERHQnWJCJNGbSpEkYOHCgTZsQAs8++yxatWrFYkykUizIRBoTEBCAgIAAm7ZvvvkGhYWFGDt2rEKpiKgiPIdMVAusX78eOp0OTzzxhNJRiMgBFmQijbt58yY+++wz9O7dG61atVI6DhE5wEPWRBqXlJSEvLy8Cg9XWywWWCwW63uz2Vzd0YjoL7hCJtK49evXo06dOhg9erTsuNjYWBiNRuvLZDLVUEIiAliQiTTt2rVrSExMRFBQELy8vGTHRkVFIT8/3/rKysqqoZREBPCQ9X1HJ0my/e0/OC3b/3rjLxz2jX5xhuxc94uHZPtJfbZt23bHV1dLkgSpgp8vIqo+XCETaVhCQgI8PT0REhKidBQiqgALMpFGXb58GcnJyXj88cfh7n7330VNRDWLBZlIozZu3IiSkhJ+GQjRfYIFmUijEhIS0KhRo3Jfo0lE6sSLuog0KjU1tUq2kxYTBIPBUCXbIiLHuEImIiJSAa6Q7zPnZneT7f+Pz1LZ/sisYId9nnvTZOeWyfYSEdG94AqZiIhIBbhCJiJZ/tFJcJIqd9tUZtzQKk5DpF1cIRMREakACzKRRh0/fhwhISFo0KAB3N3d4e/vjw8++EDpWETkAA9ZE2nQ3r17MWzYMHTt2hWzZs2Cp6cnzp49i+zsbKWjEZEDLMhEGmM2mxEeHo6hQ4di06ZNcHLigTCi+wF/U4k0Zv369cjNzcW8efPg5OSE69evo6yMN60RqR1XyCqk92rgsO9fY5bLzrWIEtn+C8+2cNhXdj1dPhjdF5KTk2EwGJCTk4PHHnsMGRkZ8PDwwJNPPonFixfD1dVV6YhEZAdXyEQac+bMGZSUlGD48OEICgrC5s2b8fTTT+Ojjz7CU0895XCexWKB2Wy2eRFRzeEKmUhjrl27hsLCQjz77LPWq6pHjBiB4uJirFixAnPmzEH79u3LzYuNjUVMTExNxyWi/8MVMpHGuLm5AQDGjBlj0/7EE08AcPzQiaioKOTn51tfWVlZ1RuUiGxwhUykMT4+Pjh58iQaN25s096oUSMAwJUr
V+zOkyQJkiRVez4iso8rZCKN6d69OwAgJyfHpv3ChQsAgIYNG9Z4JiKqGAsykcaMHj0aALBq1Sqb9n//+99wdnZG//79FUhFRBXhIWsV+n2tl8O+v0ny95M+sPZF2f7W31XNQ+tJvbp27Yqnn34aH3/8MUpKShAYGIiUlBR8/vnniIqKgo+Pj9IRicgOFmQiDfroo4/QokULxMfHY+vWrWjZsiUWL16MadOmKR2NiBxgQSbSoDp16iA6OhrR0dFKRyGiO8SCTESy0mKCYDAYlI5BpHm8qIuIiEgFWJCJiIhUgIesiUiWf3QSnCT3e9pGZtzQKkpDpF1cIRMREakAV8gKyJ3aW7b/wIPvOex7PidQdm7bt9Nk+/lUXO1LSUnBgAED7PalpqaiV69eNZyIiO4ECzKRRk2dOhWPPPKITVu7du0USkNEFWFBJtKovn37YuTIkUrHIKI7xHPIRBpWUFCAkpISpWMQ0R1gQSbSqKeeegoGgwGurq4YMGAAjh49qnQkIpLBQ9ZEGuPi4oJ//vOfGDJkCLy9vZGeno6FCxeib9++OHDgALp27Wp3nsVigcVisb43m801FZmIwIJMpDm9e/dG797/u5I/JCQEI0eORJcuXRAVFYU9e/bYnRcbG4uYmJiaiklEt2FBrial/bs57Pto2oeycyWd4/8tp+d2lp3rWnBYPhjVSu3atcPw4cOxZcsWlJaWQq/XlxsTFRWF6dOnW9+bzWaYTKaajElUq7EgE9USJpMJxcXFuH79ut2HRUiSBEmSFEhGRAAv6iKqNX755Re4urrC09NT6ShEZAcLMpHGXL58uVzb999/j+3bt2PQoEFwcuKvPZEa8ZA1kcaEhobCzc0NvXv3RqNGjZCeno6VK1fC3d0dcXFxSscjIgdYkIk05rHHHkNCQgIWLVoEs9mMhg0bYsSIEYiOjuZXZxKpGAsykcZMnToVU6dOVToGEd0lFmQikpUWE2T3qmwiqlosyNUk8x8uDvsekXSyczvGv+Cwr9WO1Epnqojeq4Fsf4mf43tSndMzZeeWXs2vTCQiolqDl1sSERGpAAsyERGRCvCQNRHJ8o9OgpPkXi3bzowbWi3bJbofcYVMRESkAizIRBo3b9486HQ6+Pv7Kx2FiGSwIBNpWHZ2Nt5++214eHgoHYWIKsBzyEQa9tJLL6FXr14oLS3F77//rnQcIpLBglxJZX0eku2fH7LeYd+eIvkLZNouTHfYJ7p0kJ17Nqy+bL++2PE90GNH/Fd2bpRXssO+Fy8EyM4984hsN1WD/fv3Y9OmTThx4gSmTJmidBwiqgALMpEGlZaWYsqUKZg4cSI6d+58R3MsFgssFov1vdlsrq54RGQHCzKRBn300Uc4f/48kpMdH9W4XWxsLGJiYqoxFRHJ4UVdRBqTl5eH2bNnY9asWWjYsOEdz4uKikJ+fr71lZWVVY0pieh2XCETacwbb7yBBg0a3PV5Y0mSIElSNaUiooqwIBNpyJkzZ7By5Uq89957uHDhgrX9xo0buHnzJjIzM2EwGNCggfyDRIio5vGQNZGG5OTkoKysDFOnTkXr1q2tr0OHDiEjIwOtW7fGnDlzlI5JRHZwhVxJl3rJ37r0mMdVh32v/dZNdu6vHzd32Lfr4RWyc5vpq+c7hyvygMcF2f4z8KqhJLWbv78/tm7dWq79jTfeQEFBAd5//320bdtWgWREVBEWZCIN8fb2xmOPPVau/b333gMAu31EpA48ZE1ERKQCXCET1QIpKSmVnpsWEwSDwVB1YYjILq6QiYiIVIAFmYiISAVYkImIiFSA55CJSJZ/dBKcpHu/nS4zbmgVpCHSLhbkSvL4+2+Vnvt2o+OV7h9+ZpTs3MauBbL9B3Z1cdjX5hP5e4l1H1sc9g1ulCY7l4iI5PGQNZHGnDx5EqNGjUKbNm3g7u4Ob29v9OvXDzt27FA6GhHJ4AqZSGPOnz+PgoICREREwMfHB4WFhdi8eTNCQkKwYsUKREZG
Kh2RiOxgQSbSmCFDhmDIkCE2bZMnT0b37t2xaNEiFmQileIha6JaQK/Xw2Qy4erVq0pHISIHuEIm0qjr16+jqKgI+fn52L59O3bv3o3Q0FClYxGRAyzIRBo1Y8YMrFjx59PBnJycMGLECCxZssTheIvFAovlf1fSm83mas9IRP/DglxJOp2otm13OzLWYV/z6UWyc7N/uSjb36bJWYd9P09pIzs3vf1Sh32+G5+XndsOB2X7qepNmzYNI0eOxIULF/DZZ5+htLQUxcXFDsfHxsYiJiamBhMS0V/xHDKRRnXo0AEDBw5EeHg4du7ciWvXrmHYsGEQwv5fJqOiopCfn299ZWVl1XBiotqNBZmolhg5ciSOHDmCjIwMu/2SJMFgMNi8iKjmsCAT1RJFRX+e7sjPz1c4CRHZw4JMpDG//Vb+a11v3ryJtWvXws3NDR07dlQgFRFVhBd1EWnMpEmTYDab0a9fPzRr1gyXLl1CQkICTp8+jXfffReenp5KRyQiO1iQiTQmNDQUq1atwvLly5GXl4e6deuie/fumD9/PkJCQpSOR0QOsCATaUxYWBjCwsKqbHtpMUG8wIuoBrAgV9LvPzaSH+D4KYcVul4oOew780w92bmlbk1k+78ZsdBhXyO9/DNv5e41bjed9xkTEd0LXtRFRESkAizIREREKsBD1kQkyz86CU6S/OkMOZlxQ6swDZF2cYVMRESkAizIRBpz5MgRTJ48GZ06dYKHhwdatGiB0aNHO/zKTCJSBx6yJtKY+fPn49tvv8WoUaPQpUsXXLp0CUuWLEG3bt1w8OBB+Pv7Kx2RiOxgQa6k5v8tkR/g+AmKFYrqusdh34T+l2Tnlooy2f4fi+s47Ov36Quyc9u9kirbT+owffp0rF+/Hi4uLta20NBQdO7cGXFxcfjkk08UTEdEjrAgE2lM7969y7W1b98enTp1wqlTpxRIRER3gueQiWoBIQRyc3Ph7e2tdBQicoAFmagWSEhIQE5ODkJDQx2OsVgsMJvNNi8iqjksyEQad/r0abzwwgsICAhARESEw3GxsbEwGo3Wl8lkqsGURMSCTKRhly5dwtChQ2E0GrFp0ybo9XqHY6OiopCfn299ZWVl1WBSIuJFXUQalZ+fj8GDB+Pq1av4+uuv4ePjIztekiRIkuMHmxBR9WJBJtKgGzduYNiwYcjIyEBycjI6duyodCQiqgALciW5fp0u29/hE8f39NY7Lb/tWa+tcdhX0X3Gcp8LAD7flDrsa/ffH2Tnyn8yqUVpaSlCQ0ORmpqKxMREBAQEKB2JiO4ACzKRxsyYMQPbt2/HsGHD8Mcff5T7IpBx48YplIyI5LAgE2nMd999BwDYsWMHduzYUa6fBZlInViQiTQmJSWlSreXFhMEg8FQpdskovJ42xMREZEKsCATERGpAAsyERGRCvAcMhHJ8o9OgpPkfs/byYwbWgVpiLSLBbmSyq5fl+1vcw/PDl4a7+u4r4K5bVD5z+V9xkREyuEhayKNuXbtGqKjoxEcHIwGDRpAp9Nh9erVSsciogqwIBNpzO+//445c+bg1KlTePDBB5WOQ0R3iIesiTSmadOmuHjxIpo0aYKjR4/ikUceUToSEd0BrpCJNEaSJDRp0kTpGER0l1iQiYiIVICHrIkIAGCxWGCxWKzvzWazgmmIah+ukIkIABAbGwuj0Wh9mUwmpSMR1SosyEQEAIiKikJ+fr71lZWVpXQkolqFh6yJCMCfF4NJkqR0DKJaiytkIiIiFWBBJiIiUgEesibSoCVLluDq1au4cOECAGDHjh3Izs4GAEyZMgVGo1HJeERkBwsykQYtXLgQ58+ft77fsmULtmzZAgAYN24cCzKRCrEgE2lQZmam0hGI6C6xIBORrLSYIBgMBqVjEGkeL+oiIiJSARZkIiIiFeAhayKS5R+dBCfJ/Z62kRk3tIrSEGkXV8hEREQqwIJMpEEWiwWvvPIKfHx84Obmhp49e2Lfvn1KxyIi
GSzIRBo0fvx4LFq0CGPHjsX7778PvV6PIUOG4JtvvlE6GhE5wHPIRBpz+PBhbNiwAQsWLMBLL70EAAgPD4e/vz9mzpyJAwcOKJyQiOzhCplIYzZt2gS9Xo/IyEhrm6urKyZMmIDU1FQ+VpFIpViQiTTmxIkT8PX1LfdlHj169AAAfPfddwqkIqKK8JA1kcZcvHgRTZs2Ldd+q+3WAyduZ7FYYLFYrO/NZnP1BCQiu7hCJtKYoqIiSJJUrt3V1dXab09sbCyMRqP1ZTKZqjUnEdliQSbSGDc3N5uV7i03btyw9tsTFRWF/Px864vnmolqFg9ZE2lM06ZNkZOTU6794sWLAAAfHx+78yRJsruyJqKawRUykcY89NBDyMjIKHcO+NChQ9Z+IlIfFmQijRk5ciRKS0uxcuVKa5vFYkF8fDx69uzJc8NEKsVD1kQa07NnT4waNQpRUVH47bff0K5dO6xZswaZmZlYtWqV0vGIyAEWZCINWrt2LWbNmoV169bhypUr6NKlC3bu3Il+/fopHY2IHGBBJtIgV1dXLFiwAAsWLFA6ChHdIRZkIpKVFhNU7lu/iKjq8aIuIiIiFWBBJiIiUgEWZCIiIhVgQSYiIlIBFmQiIiIVYEEmIiJSARZkIiIiFeB9yERklxACAMo9pIKI7s6t36Fbv1OOsCATkV15eXkAwIdREFWRgoICGI1Gh/0syERkV4MGDQAAv/76q+wfImplNpthMpmQlZV1X37TGPMrqyrzCyFQUFDg8Fnkt9xxQd5X9vk9BSKi+4uT05+XmBiNxvvyD9RbDAYD8yuI+f90J3+p5UVdREREKsCCTEREpAIsyERklyRJiI6OhiRJSkepFOZXFvPfPZ2o6DpsIiIiqnZcIRMREakACzIREZEKsCATERGpAAsyERGRCrAgE9USS5cuRatWreDq6oqePXvi8OHDsuM///xzdOjQAa6urujcuTN27dpl0y+EwOzZs9G0aVO4ublh4MCBOHPmjCry/+tf/0Lfvn1Rv3591K9fHwMHDiw3fvz48dDpdDav4OBgVeRfvXp1uWyurq42Y9S8//v3718uv06nw9ChQ61janL/79+/H8OGDYOPjw90Oh22bdtW4ZyUlBR069YNkiShXbt2WL16dbkxd/s7VSFBRJq3YcMG4eLiIj7++GNx8uRJ8cwzz4h69eqJ3Nxcu+O//fZbodfrxTvvvCPS09PFG2+8IerUqSN+/PFH65i4uDhhNBrFtm3bxPfffy9CQkJE69atRVFRkeL5n3jiCbF06VJx4sQJcerUKTF+/HhhNBpFdna2dUxERIQIDg4WFy9etL7++OOPKs9emfzx8fHCYDDYZLt06ZLNGDXv/7y8PJvsaWlpQq/Xi/j4eOuYmtz/u3btEq+//rrYsmWLACC2bt0qO/6XX34R7u7uYvr06SI9PV18+OGHQq/Xiz179ljH3O0+uRMsyES1QI8ePcQLL7xgfV9aWip8fHxEbGys3fGjR48WQ4cOtWnr2bOnmDRpkhBCiLKyMtGkSROxYMECa//Vq1eFJEni008/VTz/7UpKSkTdunXFmjVrrG0RERFi+PDhVR3VrrvNHx8fL4xGo8Pt3W/7f/HixaJu3bri2rVr1raa3P9/dScFeebMmaJTp042baGhoSIoKMj6/l73iT08ZE2kccXFxTh27BgGDhxobXNycsLAgQORmppqd05qaqrNeAAICgqyjj937hwuXbpkM8ZoNKJnz54Ot1mT+W9XWFiImzdvWh+YcUtKSgoaNWoEPz8/PPfcc9YnXFWlyua/du0aWrZsCZPJhOHDh+PkyZPWvvtt/69atQphYWHw8PCwaa+J/V8ZFf38V8U+sYcFmUjjfv/9d5SWlqJx48Y27Y0bN8alS5fszrl06ZLs+Fv/vJttVlZl8t/ulVdegY+Pj80foMHBwVi7di2++OILzJ8/H1999RUGDx6M0tJSxfP7+fnh448/RmJiIj755BOUlZWhd+/e
yM7OBnB/7f/Dhw8jLS0NEydOtGmvqf1fGY5+/s1mM4qKiqrkZ9IePn6RiDQtLi4OGzZsQEpKis2FUWFhYdZ/79y5M7p06YK2bdsiJSUFjz76qBJRrQICAhAQEGB937t3bzzwwANYsWIF5s6dq2Cyu7dq1Sp07twZPXr0sGlX8/5XClfIRBrn7e0NvV6P3Nxcm/bc3Fw0adLE7pwmTZrIjr/1z7vZZmVVJv8tCxcuRFxcHPbu3YsuXbrIjm3Tpg28vb3x888/33Pmv7qX/LfUqVMHXbt2tWa7X/b/9evXsWHDBkyYMKHCz6mu/V8Zjn7+DQYD3NzcquT/qT0syEQa5+Ligu7du+OLL76wtpWVleGLL76wWYX9VUBAgM14ANi3b591fOvWrdGkSRObMWazGYcOHXK4zZrMDwDvvPMO5s6diz179uDhhx+u8HOys7ORl5eHpk2bVknuWyqb/69KS0vx448/WrPdD/sf+PPWOYvFgnHjxlX4OdW1/yujop//qvh/alelLwcjovvGhg0bhCRJYvXq1SI9PV1ERkaKevXqWW+lefLJJ8Wrr75qHf/tt98KZ2dnsXDhQnHq1CkRHR1t97anevXqicTERPHDDz+I4cOHV+ttN3eTPy4uTri4uIhNmzbZ3FZTUFAghBCioKBAvPTSSyI1NVWcO3dOJCcni27duon27duLGzduKJ4/JiZGJCUlibNnz4pjx46JsLAw4erqKk6ePGnz36jW/X9Lnz59RGhoaLn2mt7/BQUF4sSJE+LEiRMCgFi0aJE4ceKEOH/+vBBCiFdffVU8+eST1vG3bnt6+eWXxalTp8TSpUvt3vYkt08qgwWZqJb48MMPRYsWLYSLi4vo0aOHOHjwoLUvMDBQRERE2Iz/7LPPhK+vr3BxcRGdOnUS//nPf2z6y8rKxKxZs0Tjxo2FJEni0UcfFT/99JMq8rds2VIAKPeKjo4WQghRWFgoBg0aJBo2bCjq1KkjWrZsKZ555pl7+sO0KvNPmzbNOrZx48ZiyJAh4vjx4zbbU/P+F0KI06dPCwBi79695bZV0/v/yy+/tPvzcCtzRESECAwMLDfnoYceEi4uLqJNmzY291DfIrdPKoOPXyQiIlIBnkMmIiJSARZkIiIiFWBBJiIiUgEWZCIiIhVgQSYiIlIBFmQiIiIVYEEmIiJSARZkIiIiFWBBJiIiUgEWZCIiIhVgQSYiIlIBFmQiIiIV+P+u3XndsGhz1gAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 600x700 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dataiter = iter(trainloader)\n",
    "images, labels = next(dataiter)\n",
    "# Flatten each 1x28x28 image into a 784-element vector\n",
    "images = images.view(images.shape[0], -1)\n",
    "\n",
    "# Forward pass through the (still untrained) network; no gradients are needed here\n",
    "img_idx = 0\n",
    "with torch.no_grad():\n",
    "    logits = model(images)\n",
    "\n",
    "# Convert the logits into class probabilities\n",
    "prediction = F.softmax(logits, dim=1)\n",
    "\n",
    "img = images[img_idx]\n",
    "view_classification(img.view(1, 28, 28), prediction[img_idx])"
   ]
  },
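  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`F.softmax(logits, dim=1)` exponentiates each logit and divides by the sum of the exponentials in the same row, so each image's 10 outputs become non-negative probabilities that sum to 1; the predicted digit is the index with the largest probability. Since the network is untrained at this point, these probabilities do not yet say anything meaningful about the digit. A small check with made-up logits for 3 classes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Made-up logits: 2 images x 3 classes (the real model outputs 10 per image)\n",
    "logits = torch.tensor([[2.0, 1.0, 0.0],\n",
    "                       [0.0, 0.0, 0.0]])\n",
    "\n",
    "probs = F.softmax(logits, dim=1)\n",
    "\n",
    "# Softmax by hand: exp(logit) / sum of exp(logits) over the row\n",
    "manual = logits.exp() / logits.exp().sum(dim=1, keepdim=True)\n",
    "\n",
    "print(probs)\n",
    "print(probs.sum(dim=1))     # each row sums to 1\n",
    "print(probs.argmax(dim=1))  # predicted class per image"
   ]
  },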
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset MNIST\n",
       "    Number of datapoints: 60000\n",
       "    Root location: MNIST_data/\n",
       "    Split: Train\n",
       "    StandardTransform\n",
       "Transform: Compose(\n",
       "               ToTensor()\n",
       "               Normalize(mean=0.5, std=0.5)\n",
       "           )"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainloader.dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
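    "### Back Propagation\n",
    "Next we train the network. For each batch we compute the cross-entropy loss between the network's output logits and the true labels, back-propagate this error to get the gradient of the loss with respect to every weight, and let the optimizer use those gradients to adjust the weights."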
   ]
  },
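  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a single image with logits `z` and true label `y`, the cross-entropy loss is `-log(softmax(z)[y])`: it approaches 0 when the network assigns probability close to 1 to the correct digit and grows as that probability shrinks. `nn.CrossEntropyLoss` applies the log-softmax itself, which is why `forward` returns raw logits. The sketch below, with made-up logits and labels, checks the built-in loss against the formula averaged over the batch (the default `reduction='mean'`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Made-up logits for a batch of 2 images over 3 classes, and their true labels\n",
    "logits = torch.tensor([[2.0, 1.0, 0.1],\n",
    "                       [0.5, 2.5, 0.3]])\n",
    "labels = torch.tensor([0, 1])\n",
    "\n",
    "# Built-in loss: log-softmax + negative log-likelihood, averaged over the batch\n",
    "loss = nn.CrossEntropyLoss()(logits, labels)\n",
    "\n",
    "# Same thing by hand: take the log-probability of each true class, negate, average\n",
    "log_probs = F.log_softmax(logits, dim=1)\n",
    "manual = -log_probs[torch.arange(len(labels)), labels].mean()\n",
    "\n",
    "print(loss.item(), manual.item())"
   ]
  },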
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
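    "# Start from a freshly initialized model. The optimizer updates the weights using the gradients\n",
    "# computed by back-propagation; CrossEntropyLoss expects raw logits (it applies log-softmax itself)\n",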
    "model = NN()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "loss_fn = nn.CrossEntropyLoss()"
   ]
  },
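  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each training step follows the same pattern: `optimizer.zero_grad()` clears the previous gradients, the forward pass produces logits, the loss is computed, `loss.backward()` back-propagates it so every parameter's `.grad` is filled in, and `optimizer.step()` updates the weights. Below is a self-contained sketch of that loop on one made-up random batch (not MNIST), using an `nn.Sequential` stand-in with the same layer sizes; repeating the step on the same batch should lower its loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Stand-in network with the same layer sizes as NN\n",
    "net = nn.Sequential(nn.Linear(784, 192), nn.ReLU(),\n",
    "                    nn.Linear(192, 128), nn.ReLU(),\n",
    "                    nn.Linear(128, 10))\n",
    "optimizer = optim.Adam(net.parameters(), lr=0.001)\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "\n",
    "# One made-up batch: 16 flattened 'images' with random digit labels\n",
    "images = torch.randn(16, 784)\n",
    "labels = torch.randint(0, 10, (16,))\n",
    "\n",
    "losses = []\n",
    "for step in range(20):\n",
    "    optimizer.zero_grad()           # clear gradients from the previous step\n",
    "    logits = net(images)            # forward pass\n",
    "    loss = loss_fn(logits, labels)  # cross-entropy on raw logits\n",
    "    loss.backward()                 # back-propagation: compute gradients\n",
    "    optimizer.step()                # update the weights using the gradients\n",
    "    losses.append(loss.item())\n",
    "\n",
    "print(f'first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}')"
   ]
  },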
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1/1 Loss: 2.0367 Test accuracy: 0.5667\n",
      "Epoch: 1/1 Loss: 1.3270 Test accuracy: 0.7145\n",
      "Epoch: 1/1 Loss: 0.8737 Test accuracy: 0.7718\n",
      "Epoch: 1/1 Loss: 0.7358 Test accuracy: 0.8041\n",
      "Epoch: 1/1 Loss: 0.6454 Test accuracy: 0.8501\n",
      "Epoch: 1/1 Loss: 0.5656 Test accuracy: 0.8515\n",
      "Epoch: 1/1 Loss: 0.5023 Test accuracy: 0.8678\n",
      "Epoch: 1/1 Loss: 0.4775 Test accuracy: 0.8712\n",
      "Epoch: 1/1 Loss: 0.4399 Test accuracy: 0.8881\n",
      "Epoch: 1/1 Loss: 0.3717 Test accuracy: 0.8925\n",
      "Epoch: 1/1 Loss: 0.3842 Test accuracy: 0.8876\n",
      "Epoch: 1/1 Loss: 0.3793 Test accuracy: 0.8912\n",
      "Epoch: 1/1 Loss: 0.3903 Test accuracy: 0.8866\n",
      "Epoch: 1/1 Loss: 0.3750 Test accuracy: 0.8961\n",
      "Epoch: 1/1 Loss: 0.3636 Test accuracy: 0.8948\n",
      "Epoch: 1/1 Loss: 0.3437 Test accuracy: 0.9005\n",
      "Epoch: 1/1 Loss: 0.3402 Test accuracy: 0.8993\n",
      "Epoch: 1/1 Loss: 0.3325 Test accuracy: 0.9020\n",
      "Epoch: 1/1 Loss: 0.3401 Test accuracy: 0.8968\n",
      "Epoch: 1/1 Loss: 0.3009 Test accuracy: 0.9013\n",
      "Epoch: 1/1 Loss: 0.3642 Test accuracy: 0.9043\n",
      "Epoch: 1/1 Loss: 0.3764 Test accuracy: 0.9129\n",
      "Epoch: 1/1 Loss: 0.3457 Test accuracy: 0.9108\n",
      "Epoch: 1/1 Loss: 0.3554 Test accuracy: 0.9101\n",
      "Epoch: 1/1 Loss: 0.3252 Test accuracy: 0.9035\n",
      "Epoch: 1/1 Loss: 0.3303 Test accuracy: 0.9159\n",
      "Epoch: 1/1 Loss: 0.2999 Test accuracy: 0.9167\n",
      "Epoch: 1/1 Loss: 0.2423 Test accuracy: 0.9173\n",
      "Epoch: 1/1 Loss: 0.2425 Test accuracy: 0.9165\n",
      "Epoch: 1/1 Loss: 0.2315 Test accuracy: 0.9189\n",
      "Epoch: 1/1 Loss: 0.2647 Test accuracy: 0.9185\n",
      "Epoch: 1/1 Loss: 0.3319 Test accuracy: 0.9181\n",
      "Epoch: 1/1 Loss: 0.2910 Test accuracy: 0.9198\n",
      "Epoch: 1/1 Loss: 0.2430 Test accuracy: 0.9242\n",
      "Epoch: 1/1 Loss: 0.2886 Test accuracy: 0.9155\n",
      "Epoch: 1/1 Loss: 0.3167 Test accuracy: 0.9222\n",
      "Epoch: 1/1 Loss: 0.2841 Test accuracy: 0.9282\n",
      "Epoch: 1/1 Loss: 0.2755 Test accuracy: 0.9272\n",
      "Epoch: 1/1 Loss: 0.3069 Test accuracy: 0.9235\n",
      "Epoch: 1/1 Loss: 0.3281 Test accuracy: 0.9162\n",
      "Epoch: 1/1 Loss: 0.2596 Test accuracy: 0.9238\n",
      "Epoch: 1/1 Loss: 0.2806 Test accuracy: 0.9303\n",
      "Epoch: 1/1 Loss: 0.2529 Test accuracy: 0.9303\n",
      "Epoch: 1/1 Loss: 0.3011 Test accuracy: 0.9291\n",
      "Epoch: 1/1 Loss: 0.2588 Test accuracy: 0.9297\n",
      "Epoch: 1/1 Loss: 0.2408 Test accuracy: 0.9257\n"
     ]
    }
   ],
   "source": [
    "# Train the network, evaluating test accuracy every `eval_freq` steps\n",
    "\n",
    "epochs = 1\n",
    "steps = 0\n",
    "running_loss = 0\n",
    "eval_freq = 10\n",
    "for e in range(epochs):\n",
    "    for images, labels in trainloader:\n",
    "        steps += 1\n",
    "        # Flatten each 28x28 image into a 784-long vector\n",
    "        images = images.view(images.shape[0], -1)\n",
    "\n",
    "        optimizer.zero_grad()             # clear gradients from the previous step\n",
    "        output = model(images)            # forward pass\n",
    "        loss = loss_fn(output, labels)    # cross-entropy loss\n",
    "        loss.backward()                   # back propagation\n",
    "        optimizer.step()                  # update the weights\n",
    "\n",
    "        running_loss += loss.item()\n",
    "\n",
    "        if steps % eval_freq == 0:\n",
    "            # Measure accuracy on the test set; no gradients are needed here\n",
    "            model.eval()\n",
    "            accuracy = 0\n",
    "            with torch.no_grad():\n",
    "                for ii, (images, labels) in enumerate(testloader):\n",
    "                    images = images.view(images.shape[0], -1)\n",
    "                    predicted = model.predict(images)\n",
    "                    equality = (labels == predicted.argmax(dim=1))\n",
    "                    accuracy += equality.float().mean().item()\n",
    "            model.train()\n",
    "\n",
    "            print(\"Epoch: {}/{}\".format(e+1, epochs),\n",
    "                  \"Loss: {:.4f}\".format(running_loss/eval_freq),\n",
    "                  \"Test accuracy: {:.4f}\".format(accuracy/(ii+1)))\n",
    "            running_loss = 0"
   ]
  },
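  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop averages per-batch accuracies, which can be slightly off when the last test batch is smaller than the others. Below is a minimal sketch of a standalone alternative; the helper name `evaluate_accuracy` is hypothetical. It counts correct predictions over all samples and runs under `torch.no_grad()`, since no gradients are needed. It assumes the network maps flattened 784-pixel images to 10 class scores. The demo uses an untrained toy layer and synthetic data so the cell runs on its own. With the notebook's own objects it could be called as `evaluate_accuracy(model, testloader)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "def evaluate_accuracy(net, loader):\n",
    "    # Fraction of correctly classified samples over the whole loader\n",
    "    net.eval()                   # disable dropout/batch-norm training behaviour, if any\n",
    "    correct, total = 0, 0\n",
    "    with torch.no_grad():        # evaluation needs no autograd graph\n",
    "        for images, labels in loader:\n",
    "            images = images.view(images.shape[0], -1)   # flatten to 784 values per image\n",
    "            predicted = net(images).argmax(dim=1)       # index of the highest score\n",
    "            correct += (predicted == labels).sum().item()\n",
    "            total += labels.shape[0]\n",
    "    net.train()                  # back to training mode\n",
    "    return correct / total\n",
    "\n",
    "# Synthetic demo: 100 random 'images', random labels, batches of 32 (last batch has 4)\n",
    "toy_net = nn.Linear(784, 10)\n",
    "toy_data = TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))\n",
    "toy_loader = DataLoader(toy_data, batch_size=32)\n",
    "acc = evaluate_accuracy(toy_net, toy_loader)\n",
    "print(f'accuracy: {acc:.4f}')"
   ]
  },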
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAERCAYAAACq8dRTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAlc0lEQVR4nO3deVxU9f7H8fcAMoDKqOCG4pILlkuZpeE1tZs/Rbti+XOhTbxptqn50DYqI/UWll5tcSnvzzAN07KrVtfS6Je3RVwqW9zSTA3cMksGBUeB8/ujn3MbmRmQgDkeXs/H4zxqvsvMuxP48XuWOTbDMAwBAICACgp0AAAAQEEGAMAUKMgAAJgABRkAABOgIAMAYAIUZAAATICCDACACVCQAQAwAQoyAAAmQEEGgACy2WwaO3Zshb3fokWLZLPZ9Pnnn5c6tnfv3urdu7f79f79+2Wz2bRo0SJ325NPPimbzVZh+eAbBRkAvDhX2M5tYWFhatu2rcaOHaujR48GOl5APf3001q1alWgY1gOBRkA/Jg6daqWLFmiOXPmqHv37po/f77i4+OVn58f6Gh/2Lp167Ru3Tq/Yx5//HEVFBR4tFGQK0dIoAMAgJn1799fV111lSRp9OjRioqK0qxZs7R69WrdfPPNJcafOnVKNWvWrOqY5RIaGlrqmJCQEIWEUCqqAitkALgAf/7znyVJ+/bt08iRI1WrVi3t3btXAwYMUO3atXXrrbdK+q0wT5o0SbGxsbLb7YqLi9PMmTPl6wF7GRkZiouLU1hYmLp06aKPP/7Yo//AgQO69957FRcXp/DwcEVFRWno0KHav3+/1/fLz8/XXXfdpaioKEVGRmrEiBH69ddfPcacfw7Zm/PPIdtsNp06dUqvvvqq+3D+yJEj9dFHH8lms2nlypUl3mPp0qWy2WzKysry+1nVHX/tAYALsHfvXklSVFSUJKmwsFD9+vVTjx49NHPmTEVERMgwDCUmJuqjjz7SqFGjdMUVV2jt2rV68MEHdfDgQc2ePdvjPf/9739r+fLlGj9+vOx2u+bNm6eEhARt3rxZHTp0kCRt2bJFGzZsUFJSkpo2bar9+/dr/vz56t27t3bs2KGIiAiP9xw7dqzq1KmjJ598Ut99953mz5+vAwcOaP369X/oIq0lS5Zo9OjR6tq1q8aMGSNJatWqla655hrFxsYqIyNDN910k8ecjIwMtWrVSvHx8eX+3GrBAACUkJ6ebkgyMjMzjWPHjhnZ2dnGsmXLjKioKCM8PNzIyckxkpOTDUnGI4884jF31apVhiTjb3/7m0f7kCFDDJvNZnz//ffuNkmGJOPzzz93tx04cMAICwszbrrpJndbfn5+iYxZWVmGJGPx4sUlcnfp0sU4c+aMu/3ZZ581JBmrV692t/Xq1cvo1auX+/W+ffsMSUZ6erq7LTU11Ti/VNSsWdNITk4ukSclJcWw2+3GiRMn3G0//fSTERISYqSmppYYD08csgYAP/r06aP69esrNjZWSUlJqlWrllauXKkmTZq4x9xzzz0ec9asWaPg4GCNHz/eo33SpEkyDEPvvfeeR3t8fLy6dOnift2sWTMNGjRIa9euVVFRkSQpPDzc3X/27FkdP35crVu3Vp06dfTll1+WyD1mzBjVqFHDI2NISIjWrFlTjr1QNiNGjJDL5dKKFSvcbcuXL1dhYaFuu+22Svtcq+CQNQD4MXfuXLVt21YhISFq2LCh4uLiFBT0n7VMSEiImjZt6jHnwIEDiomJUe3atT3aL730Unf/77Vp06bE57Zt21b5+fk6duyYGjVqpIKCAqWlpSk9PV0HDx70OBedm5tbYv7571mrVi01btzY5znnitCuXTtdffXVysjI0KhRoyT9drj6mmuuUevWrSvtc62CggwAfnTt2tV9lbU3drvdo0BXlnHjxik9PV0TJkxQfHy8HA6HbDabkpKSVFxcXOmfX1YjRozQ/fffr5yc
HLlcLm3cuFFz5swJdKyLAgUZACpY8+bNlZmZqby8PI9V8q5du9z9v7dnz54S77F7925FRESofv36kqQVK1YoOTlZf//7391jTp8+rRMnTnjNsGfPHl133XXu1ydPntThw4c1YMCAcv93nePvorCkpCRNnDhRr7/+ugoKClSjRg0NHz78D39mdcA5ZACoYAMGDFBRUVGJleHs2bNls9nUv39/j/asrCyP88DZ2dlavXq1+vbtq+DgYElScHBwiVumXnzxRfc55vMtWLBAZ8+edb+eP3++CgsLS3x2edSsWdPnXwSio6PVv39/vfbaa8rIyFBCQoKio6P/8GdWB6yQAaCCDRw4UNddd50ee+wx7d+/X5dffrnWrVun1atXa8KECWrVqpXH+A4dOqhfv34etz1J0pQpU9xj/vKXv2jJkiVyOBy67LLLlJWVpczMTPftV+c7c+aMrr/+eg0bNkzfffed5s2bpx49eigxMfEP//d16dJFmZmZmjVrlmJiYtSyZUt169bN3T9ixAgNGTJEkjRt2rQ//HnVBQUZACpYUFCQ3n77bT3xxBNavny50tPT1aJFC82YMUOTJk0qMb5Xr16Kj4/XlClT9OOPP+qyyy7TokWL1KlTJ/eY559/XsHBwcrIyNDp06f1pz/9SZmZmerXr5/XDHPmzFFGRoaeeOIJnT17VjfffLNeeOGFCnlQxKxZszRmzBj312omJyd7FOSBAweqbt26Ki4urpC/AFQXNuP8YyAAAPwBhYWFiomJ0cCBA7Vw4cJAx7locA4ZAFChVq1apWPHjmnEiBGBjnJRYYUMAKgQmzZt0jfffKNp06YpOjra6xeWwDdWyACACjF//nzdc889atCggRYvXhzoOBcdVsgAAJgAK2QAAEygzLc9/VfQ0MrMAVQrHxS/GegIAEyG+5ABeFVcXKxDhw6pdu3aFXLvKlBdGYahvLw8xcTE+P3ecwoyAK8OHTqk2NjYQMcALCM7O7vEk8F+j4IMwKtzD0XIzs5WZGRkgNMAFy+n06nY2NgSj+M8HwUZgFfnDlNHRkZSkIEKUNqpH66yBgDABCjIAACYAAUZAAAToCADAGACFGQAAEyAggwAgAlQkAEAMAEKMgAAJkBBBizmiy++UEJCgiIjI1W7dm317dtXX331VaBjASgF39SFCpE3/Bq//fOmP++z76Fbx/ida9vwdbkyVUdffvmlevToodjYWKWmpqq4uFjz5s1Tr169tHnzZsXFxQU6IgAfKMiAhUyePFnh4eHKyspSVFSUJOm2225T27Zt9eijj+qtt94KcEIAvnDIGrCQTz75RH369HEXY0lq3LixevXqpXfffVcnT54MYDoA/lCQAQtxuVwKDw8v0R4REaEzZ85o27ZtAUgFoCw4ZA1YSFxcnDZu3KiioiIFBwdLks6cOaNNmzZJkg4ePOhzrsvlksvlcr92Op2VGxaAB1bIgIXce++92r17t0aNGqUdO3Zo27ZtGjFihA4fPixJKigo8Dk3LS1NDofDvcXGxlZVbACiIAOWcvfdd+vRRx/V0qVL1b59e3Xs2FF79+7VQw89JEmqVauWz7kpKSnKzc11b9nZ2VUVG4AoyIDlPPXUUzp69Kg++eQTffPNN9qyZYuKi4slSW3btvU5z263KzIy0mMDUHU4h4wyC27bymffpGlL/c7tGFrDZ9/xjhF+50Zv8J8LJdWtW1c9evRwv87MzFTTpk3Vrl27AKYC4A8rZMDili9fri1btmjChAkKCuJXHjArVsiAhXz88ceaOnWq+vbtq6ioKG3cuFHp6elKSEjQ/fffH+h4APygIAMW0qRJEwUHB2vGjBnKy8tTy5Yt9be//U0TJ05USAi/7oCZ8RsKWEirVq20du3aQMcAUA6cUAIAwAQoyAAAmACHrFFme5Mb+Oy7seYJv3Nbv+/7EYtxC7f4nWv47QUAa2CFDACACVCQAQAwAQoyAAAmQEEGLGjPnj1KSkpS06ZNFRERoXbt
2mnq1KnKz88PdDQAPnBRF2Ax2dnZ6tq1qxwOh8aOHat69eopKytLqamp+uKLL7R69epARwTgBQUZsJglS5boxIkT+vTTT9W+fXtJ0pgxY1RcXKzFixfr119/Vd26dQOcEsD5OGQNWIzT6ZQkNWzY0KO9cePGCgoKUmhoaCBiASgFK+RqJqRJjM++U+n+/6D+oN0Mn32d5jzkd27cjM0++4zCQr9zcWF69+6tZ555RqNGjdKUKVMUFRWlDRs2aP78+Ro/frxq1qzpdZ7L5ZLL5XK/PlfYAVQNVsiAxSQkJGjatGn64IMP1LlzZzVr1kxJSUkaN26cZs+e7XNeWlqaHA6He4uNja3C1AAoyIAFtWjRQj179tSCBQv01ltv6Y477tDTTz+tOXPm+JyTkpKi3Nxc95adnV2FiQFwyBqwmGXLlmnMmDHavXu3mjZtKkkaPHiwiouL9fDDD+vmm29WVFRUiXl2u112u72q4wL4f6yQAYuZN2+eOnfu7C7G5yQmJio/P19bt24NUDIA/lCQAYs5evSoioqKSrSfPXtWklTIRXSAKVGQAYtp27attm7dqt27d3u0v/766woKClKnTp0ClAyAP5xDrmb2jGvus29n+7l+58atH+ezr1XaBr9zeYRi1XnwwQf13nvv6dprr9XYsWMVFRWld999V++9955Gjx6tmBjft74BCBwKMmAxPXv21IYNG/Tkk09q3rx5On78uFq2bKmnnnpKDz3k/35xAIFDQQYsqGvXrlqzZk2gYwC4AJxDBgDABCjIAACYAAUZAAAToCADAGACFGQAAEyAq6wtxoi/3G//tb2/9dk385c4v3PjHjjss4/vfjKPkSNH6tVXX/XZn5OToyZNmlRhIgBlQUEGLOauu+5Snz59PNoMw9Ddd9+tFi1aUIwBk6IgAxYTHx+v+Ph4j7ZPP/1U+fn5uvXWWwOUCkBpOIcMVANLly6VzWbTLbfcEugoAHygIAMWd/bsWb3xxhvq3r27WrRoEeg4AHzgkDVgcWvXrtXx48dLPVztcrnkcrncr51OZ2VHA/A7rJABi1u6dKlq1KihYcOG+R2XlpYmh8Ph3mJjY6soIQCJggxY2smTJ7V69Wr169dPUVFRfsempKQoNzfXvWVnZ1dRSgASh6wvOja73W9/mxd2+e1/rOGHPvuG3T/J79yIw5v89sN8Vq1aVearq+12u+yl/HwBqDyskAELy8jIUK1atZSYmBjoKABKQUEGLOrYsWPKzMzUTTfdpIiIiEDHAVAKCjJgUcuXL1dhYSFfBgJcJCjIgEVlZGSoQYMGJb5GE4A5cVEXYFFZWVmBjgDgArBCBgDABFghX2T2PXGl3/5/xcz12z8mO8FnX6112/zOLfbbCwD4I1ghAwBgAqyQAfjVIXWtguzcNoXqY//0GwLyuayQAQAwAQoyYFFffvmlEhMTVa9ePUVERKhDhw564YUXAh0LgA8csgYsaN26dRo4cKA6d+6syZMnq1atWtq7d69ycnICHQ2ADxRkwGKcTqdGjBihG264QStWrFBQEAfCgIsBv6mAxSxdulRHjx7VU089paCgIJ06dUrFxdy0BpgdK2QTCo6q57PvHzfP9zvXZRT67T90dzOffcWndvgPhotCZmamIiMjdfDgQd14443avXu3atasqdtvv12zZ89WWFhYoCMC8IIVMmAxe/bsUWFhoQYNGqR+/frprbfe0h133KGXXnpJf/3rX33Oc7lccjqdHhuAqsMKGbCYkydPKj8/X3fffbf7qurBgwfrzJkzevnllzV16lS1adOmxLy0tDRNmTKlquMC+H+skAGLCQ8PlyTdfPPNHu233HKLJN8PnUhJSVFubq57y87OrtygADywQgYsJiYmRtu3b1fDhg092hs0aCBJ+vXXX73Os9vtstvtlZ4PgHeskAGL6dKliyTp4MGDHu2HDh2SJNWvX7/KMwEoHQUZsJhhw4ZJkhYuXOjR/j//8z8KCQlR7969A5AK
QGk4ZG1CPy+O8tn3J7v/+0kvXXy/3/6WX/HQeqvr3Lmz7rjjDr3yyisqLCxUr169tH79er355ptKSUlRTExMoCMC8IKCDFjQSy+9pGbNmik9PV0rV65U8+bNNXv2bE2YMCHQ0QD4QEEGLKhGjRpKTU1VampqoKMAKCMKMgC/tk3pp8jIyEDHACyPi7oAADABCjIAACZAQQYAwAQoyAAAmAAXdQXA0fHd/fZvuPw5n333Huzld26rp7f57eepuNa3fv16XXfddV77srKydM0111RxIgBlQUEGLGr8+PG6+uqrPdpat24doDQASkNBBizq2muv1ZAhQwIdA0AZcQ4ZsLC8vDwVFhYGOgaAMqAgAxb117/+VZGRkQoLC9N1112nzz//PNCRAPjBIWvAYkJDQ/Xf//3fGjBggKKjo7Vjxw7NnDlT1157rTZs2KDOnTt7nedyueRyudyvnU5nVUUGIAoyYDndu3dX9+7/uZI/MTFRQ4YMUadOnZSSkqL333/f67y0tDRNmTKlqmICOA8FuZIU9b7SZ99LE170O9du8/2/Zde0jn7nhuVt9h8M1VLr1q01aNAg/fOf/1RRUZGCg4NLjElJSdHEiRPdr51Op2JjY6syJlCtUZCBaiI2NlZnzpzRqVOnvD4swm63y263ByAZAImLuoBq44cfflBYWJhq1aoV6CgAvKAgAxZz7NixEm1ff/213n77bfXt21dBQfzaA2bEIWvAYoYPH67w8HB1795dDRo00I4dO7RgwQJFRERo+vTpgY4HwAcKMmAxN954ozIyMjRr1iw5nU7Vr19fgwcPVmpqKl+dCZgYBRmwmPHjx2v8+PGBjgHgAnEyCQAAE2CFXEn2/yXUZ9/VdpvfuZel3+ezr8U7WeXOVJrgqHp++wvjfN+TGrJjv9+5RSdyyxMJAKoNVsgAAJgABRkAABOgIAMAYAIUZAAATICCDFjcU089JZvNpg4dOgQ6CgA/KMiAheXk5Ojpp59WzZo1Ax0FQCm47QmwsAceeEDXXHONioqK9PPPPwc6DgA/KMjlVNzjCr/9zyQu9dn3fkGE37mtZu7w2Wd0aud37t6kun77g8/4vgf61sH/63duSlSmz777D8X7nbvnar/dqAQff/yxVqxYoa1bt2rcuHGBjgOgFBRkwIKKioo0btw4jR49Wh07dizTHJfLJZfL5X7tdDorKx4ALyjIgAW99NJLOnDggDIzfR/VOF9aWpqmTJlSiakA+MNFXYDFHD9+XE888YQmT56s+vXrl3leSkqKcnNz3Vt2dnYlpgRwPlbIgMU8/vjjqlev3gWfN7bb7bLb7ZWUCkBpKMiAhezZs0cLFizQc889p0OHDrnbT58+rbNnz2r//v2KjIxUvXr+HyQCoOpxyBqwkIMHD6q4uFjjx49Xy5Yt3dumTZu0e/dutWzZUlOnTg10TABesEIupyPX+L916caaJ3z2PfrTlX7n/vhKU599a6562e/cJsH+c1WWS2se8tu/R1FVlKR669Chg1auXFmi/fHHH1deXp6ef/55tWrVKgDJAJSGggxYSHR0tG688cYS7c8995wkee0DYA4csgYAwARYIQPVwPr16wMdAUApWCEDAGACFGQAAEyAggwAgAlQkAEAMAEu6iqnmn/+qdxzn27wZbn7B+0Z6nduw7A8v/0b1nTy2XfJa/7vJba94vLZ17/BNr9zAQD+sUIGLGb79u0aOnSoLrnkEkVERCg6Olo9e/bUO++8E+hoAPxghQxYzIEDB5SXl6fk5GTFxMQoPz9fb731lhITE/Xyyy9rzJgxgY4IwAsKMmAxAwYM0IABAzzaxo4dqy5dumjWrFkUZMCkOGQNVAPBwcGKjY3ViRMnAh0FgA+skAGLOnXqlAoKCpSbm6u3335b7733noYPHx7oWAB8oCADFjVp0iS9/PJvTwcLCgrS4MGDNWfOHJ/jXS6XXK7/XEnvdDorPSOA/6Agl5PNZlTae1+55VaffU0nFvid
m/PDYb/9lzTa67Pv+3GX+J27o81cn31tl9/rd25rbfTbj4o3YcIEDRkyRIcOHdIbb7yhoqIinTlzxuf4tLQ0TZkypQoTAvg9ziEDFtWuXTv16dNHI0aM0LvvvquTJ09q4MCBMgzvf5lMSUlRbm6ue8vOzq7ixED1RkEGqokhQ4Zoy5Yt2r17t9d+u92uyMhIjw1A1aEgA9VEQcFvpztyc3MDnASANxRkwGJ++qnk17qePXtWixcvVnh4uC677LIApAJQGi7qAizmrrvuktPpVM+ePdWkSRMdOXJEGRkZ2rVrl/7+97+rVq1agY4IwAsKMmAxw4cP18KFCzV//nwdP35ctWvXVpcuXfTMM88oMTEx0PEA+EBBBiwmKSlJSUlJgY4B4AJRkMvp528b+B/g+ymHpTqVb/fZt+fOOn7nFoU38tv/6eCZPvsaBEf4nevvXuPWE7nPGAD+CC7qAgDABCjIAACYAAUZAAAToCADAGACFGTAYrZs2aKxY8eqffv2qlmzppo1a6Zhw4b5/MpMAObAVdaAxTzzzDP67LPPNHToUHXq1ElHjhzRnDlzdOWVV2rjxo3q0KFDoCMC8IKCXE5N/7fQ/wDfT1AsVUrn9332jep9xO/cIqPYb/+3Z2r47Ov5+n1+57Z+OMtvP8xh4sSJWrp0qUJDQ91tw4cPV8eOHTV9+nS99tprAUwHwBcKMmAx3bt3L9HWpk0btW/fXjt37gxAIgBlwTlkoBowDENHjx5VdHR0oKMA8IGCDFQDGRkZOnjwoIYPH+5zjMvlktPp9NgAVB0KMmBxu3bt0n333af4+HglJyf7HJeWliaHw+HeYmNjqzAlAAoyYGFHjhzRDTfcIIfDoRUrVig4ONjn2JSUFOXm5rq37OzsKkwKgIu6AIvKzc1V//79deLECX3yySeKiYnxO95ut8tu9/1gEwCVi4IMWNDp06c1cOBA7d69W5mZmbrssssCHQlAKSjI5RT2yQ6//e1e831Pb51d/t978qOv+uwr7T5jf58rSTGfFvnsa/2/3/id6/+TYRZFRUUaPny4srKytHr1asXHxwc6EoAyoCADFjNp0iS9/fbbGjhwoH755ZcSXwRy2223BSgZAH8oyIDFfPXVV5Kkd955R++8806JfgoyYE4UZMBi1q9fH+gIAMqB254AADABCjIAACZAQQYAwAQoyAAAmIDNMAyjLAP/K2hoZWcBqo0Pit8MdIRSOZ1OORwO5ebmKjIyMtBxgItWWX+XWCEDFnPy5EmlpqYqISFB9erVk81m06JFiwIdC0ApKMiAxfz888+aOnWqdu7cqcsvvzzQcQCUEfchAxbTuHFjHT58WI0aNdLnn3+uq6++OtCRAJQBK2TAYux2uxo1ahToGAAuEAUZAAAT4JA1AEmSy+WSy+Vyv3Y6nQFMA1Q/rJABSJLS0tLkcDjcW2xsbKAjAdUKBRmAJCklJUW5ubnuLTs7O9CRgGqFQ9YAJP12MZjdbg90DKDaYoUMAIAJUJABADABDlkDFjRnzhydOHFChw4dkiS98847ysnJkSSNGzdODocjkPEAeMHDJYAAqOyHS7Ro0UIHDhzw2rdv3z61aNGi1Pfg4RJAxSjr7xIrZMCC9u/fH+gIAC4Q55ABADABCjIAACZAQQYAwAQoyAAAmAAFGQAAE6AgAxbkcrn08MMPKyYmRuHh4erWrZs++OCDQMcC4AcFGbCgkSNHatasWbr11lv1/PPPKzg4WAMGDNCnn34a6GgAfOA+ZMBiNm/erGXLlmnGjBl64IEHJEkjRoxQhw4d9NBDD2nDhg0BTgjAG1bIgMWsWLFCwcHBGjNmjLstLCxMo0aNUlZWFo9VBEyKggxYzNatW9W2bdsSX9HXtWtXSdJXX30VgFQASsMha8BiDh8+rMaNG5doP9d27oET53O5XHK5XO7XTqezcgIC8IoVMmAxBQUFstvtJdrDwsLc/d6kpaXJ4XC4t9jY2ErNCcATBRmwmPDwcI+V
7jmnT59293uTkpKi3Nxc98a5ZqBqccgasJjGjRvr4MGDJdoPHz4sSYqJifE6z263e11ZA6garJABi7niiiu0e/fuEueAN23a5O4HYD4UZMBihgwZoqKiIi1YsMDd5nK5lJ6erm7dunFuGDApDlkDFtOtWzcNHTpUKSkp+umnn9S6dWu9+uqr2r9/vxYuXBjoeAB8oCADFrR48WJNnjxZS5Ys0a+//qpOnTrp3XffVc+ePQMdDYAPNsMwjLIM/K+goZWdBag2Pih+M9ARSuV0OuVwOJSbm1viS0YAlF1Zf5c4hwwAgAlQkAEAMAEKMgAAJkBBBgDABCjIAACYAAUZAAAToCADAGACFGQAAEyAb+oC4NW57ww6/yEVAC7Mud+h0r6Hi4IMwKvjx49LEg+jACpIXl6eHA6Hz34KMgCv6tWrJ0n68ccf/f4hYlZOp1OxsbHKzs6+KL/6k/yBVZH5DcNQXl6ez2eRn1PmgnwxfPcugIoTFPTbJSYOh+Oi/AP1nMjISPIHEPl/U5a/1HJRFwAAJkBBBgDABCjIALyy2+1KTU2V3W4PdJRyIX9gkf/Clfl5yAAAoPKwQgYAwAQoyAAAmAAFGQAAE6AgAwBgAhRkoJqYO3euWrRoobCwMHXr1k2bN2/2O/7NN99Uu3btFBYWpo4dO2rNmjUe/YZh6IknnlDjxo0VHh6uPn36aM+ePabI/49//EPXXnut6tatq7p166pPnz4lxo8cOVI2m81jS0hIMEX+RYsWlcgWFhbmMcbM+793794l8ttsNt1www3uMVW5/z/++GMNHDhQMTExstlsWrVqValz1q9fryuvvFJ2u12tW7fWokWLSoy50N+pUhkALG/ZsmVGaGio8corrxjbt2837rzzTqNOnTrG0aNHvY7/7LPPjODgYOPZZ581duzYYTz++ONGjRo1jG+//dY9Zvr06YbD4TBWrVplfP3110ZiYqLRsmVLo6CgIOD5b7nlFmPu3LnG1q1bjZ07dxojR440HA6HkZOT4x6TnJxsJCQkGIcPH3Zvv/zyS4VnL0/+9PR0IzIy0iPbkSNHPMaYef8fP37cI/u2bduM4OBgIz093T2mKvf/mjVrjMcee8z45z//aUgyVq5c6Xf8Dz/8YERERBgTJ040duzYYbz44otGcHCw8f7777vHXOg+KQsKMlANdO3a1bjvvvvcr4uKioyYmBgjLS3N6/hhw4YZN9xwg0dbt27djLvuusswDMMoLi42GjVqZMyYMcPdf+LECcNutxuvv/56wPOfr7Cw0Khdu7bx6quvutuSk5ONQYMGVXRUry40f3p6uuFwOHy+38W2/2fPnm3Url3bOHnypLutKvf/75WlID/00ENG+/btPdqGDx9u9OvXz/36j+4TbzhkDVjcmTNn9MUXX6hPnz7utqCgIPXp00dZWVle52RlZXmMl6R+/fq5x+/bt09HjhzxGONwONStWzef71mV+c+Xn5+vs2fPuh+Ycc769evVoEEDxcXF6Z577nE/4aoilTf/yZMn1bx5c8XGxmrQoEHavn27u+9i2/8LFy5UUlKSatas6dFeFfu/PEr7+a+IfeINBRmwuJ9//llFRUVq2LChR3vDhg115MgRr3OOHDnid/y5f17Ie5ZXefKf7+GHH1ZMTIzHH6AJCQlavHixPvzwQz3zzDP697//rf79+6uoqCjg+ePi4vTKK69o9erVeu2111RcXKzu3bsrJydH0sW1/zdv3qxt27Zp9OjRHu1Vtf/Lw9fPv9PpVEFBQYX8THrD4xcBWNr06dO1bNkyrV+/3uPCqKSkJPe/d+zYUZ06dVKrVq20fv16XX/99YGI6hYfH6/4+Hj36+7du+vSSy/Vyy+/rGnTpgUw2YVbuHChOnbsqK5du3q0m3n/BworZMDioqOjFRwcrKNHj3q0Hz16VI0aNfI6p1GjRn7Hn/vnhbxneZUn/zkzZ87U9OnTtW7dOnXq1Mnv2EsuuUTR0dH6/vvv/3Dm
3/sj+c+pUaOGOnfu7M52sez/U6dOadmyZRo1alSpn1NZ+788fP38R0ZGKjw8vEL+n3pDQQYsLjQ0VF26dNGHH37obisuLtaHH37osQr7vfj4eI/xkvTBBx+4x7ds2VKNGjXyGON0OrVp0yaf71mV+SXp2Wef1bRp0/T+++/rqquuKvVzcnJydPz4cTVu3LhCcp9T3vy/V1RUpG+//dad7WLY/9Jvt865XC7ddtttpX5OZe3/8ijt578i/p96Ve7LwQBcNJYtW2bY7XZj0aJFxo4dO4wxY8YYderUcd9Kc/vttxuPPPKIe/xnn31mhISEGDNnzjR27txppKamer3tqU6dOsbq1auNb775xhg0aFCl3nZzIfmnT59uhIaGGitWrPC4rSYvL88wDMPIy8szHnjgASMrK8vYt2+fkZmZaVx55ZVGmzZtjNOnTwc8/5QpU4y1a9cae/fuNb744gsjKSnJCAsLM7Zv3+7x32jW/X9Ojx49jOHDh5dor+r9n5eXZ2zdutXYunWrIcmYNWuWsXXrVuPAgQOGYRjGI488Ytx+++3u8edue3rwwQeNnTt3GnPnzvV625O/fVIeFGSgmnjxxReNZs2aGaGhoUbXrl2NjRs3uvt69eplJCcne4x/4403jLZt2xqhoaFG+/btjX/9618e/cXFxcbkyZONhg0bGna73bj++uuN7777zhT5mzdvbkgqsaWmphqGYRj5+flG3759jfr16xs1atQwmjdvbtx5551/6A/Tisw/YcIE99iGDRsaAwYMML788kuP9zPz/jcMw9i1a5chyVi3bl2J96rq/f/RRx95/Xk4lzk5Odno1atXiTlXXHGFERoaalxyySUe91Cf42+flAePXwQAwAQ4hwwAgAlQkAEAMAEKMgAAJkBBBgDABCjIAACYAAUZAAAToCADAGACFGQAAEyAggwAgAlQkAEAMAEKMgAAJkBBBgDABP4PAIh3KoYJvtQAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 600x700 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
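    "# Run the trained model on img; img[None] adds a leading batch dimension\n",
    "with torch.no_grad():\n",
    "    logits = model(img[None])\n",
    "\n",
    "# Convert the network's raw scores (logits) into class probabilities\n",
    "prediction = F.softmax(logits, dim=1)\n",
    "\n",
    "view_classification(img.reshape(1, 28, 28), prediction[0])"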
   ]
  },
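  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`F.softmax` only rescales the logits into probabilities that sum to 1. The predicted digit is the index of the largest probability. Since softmax preserves the ordering of its inputs, `argmax` over the raw logits picks the same class. Here is a tiny sketch with made-up logits for a single image:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Made-up logits for one image over the 10 digit classes\n",
    "example_logits = torch.tensor([[0.1, -1.2, 0.3, 4.0, 0.0, -0.5, 0.2, 1.1, -2.0, 0.4]])\n",
    "\n",
    "example_probs = F.softmax(example_logits, dim=1)   # each row sums to 1\n",
    "predicted_class = example_probs.argmax(dim=1)      # most probable digit\n",
    "\n",
    "print(predicted_class.item())                      # 3\n",
    "print(example_probs.sum().item())                  # 1.0 up to rounding"
   ]
  },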
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After training, the model predicts the digit correctly. Before training it assigned almost equal probability to every digit, i.e. it was effectively guessing at random."
   ]
  }
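  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to put a number on 'almost equal probability' is the entropy of the predicted distribution. Over 10 classes it is at most ln 10 (about 2.303), which only the uniform distribution reaches, and it approaches 0 as the model becomes confident. The helper `prediction_entropy` below is a hypothetical illustration applied to made-up distributions. Applied to `prediction` from the cell above, a value well below 2.303 would indicate a confident prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "\n",
    "def prediction_entropy(probs):\n",
    "    # Shannon entropy (in nats) of each row of a batch of probability vectors;\n",
    "    # clamping avoids log(0) for classes with zero probability\n",
    "    return -(probs * torch.log(probs.clamp(min=1e-12))).sum(dim=1)\n",
    "\n",
    "uniform = torch.full((1, 10), 0.1)                # a network that is just guessing\n",
    "confident = torch.tensor([[0.01] * 9 + [0.91]])   # a made-up confident prediction\n",
    "\n",
    "print(prediction_entropy(uniform).item(), math.log(10))   # both about 2.3026\n",
    "print(prediction_entropy(confident).item())              # about 0.50"
   ]
  }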
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
